Compare DMS to natural sequence evolution¶

In [1]:
# this cell is tagged parameters for papermill parameterization
dms_summary_csv = None  # input CSV summarizing DMS phenotype measurements
growth_rates_csv = None  # input CSV of clade growth-rate estimates
pango_consensus_seqs_json = None  # input JSON of Pango clade consensus-sequence info
starting_clades = None  # analyze clades descended from these Pango clades
dms_clade = None  # clade whose sequence the DMS experiments used
n_random = None  # number of randomizations for the null distribution
exclude_clades = None  # clades to drop from the analysis
pango_dms_phenotypes_csv = None  # output CSV of per-clade DMS phenotypes
pango_by_date_html = None  # output chart of phenotypes vs designation date
pango_affinity_vs_escape_html = None  # output affinity-vs-escape scatter chart
pango_dms_vs_growth_regression_html = None  # output regression chart (full spike)
pango_dms_vs_growth_regression_by_domain_html = None  # output regression chart (by domain)
pango_dms_vs_growth_corr_html = None  # output simple-correlation chart (full spike)
pango_dms_vs_growth_corr_by_domain_html = None  # output simple-correlation chart (by domain)
exclude_clades_with_muts = None  # drop clades carrying any of these mutations
In [2]:
# Parameters (values injected by papermill)
starting_clades = ["BA.2", "BA.5", "XBB"]
dms_clade = "XBB.1.5"
dms_summary_csv = "results/summaries/summary.csv"
growth_rates_csv = "MultinomialLogisticGrowth/model_fits/rates.csv"
pango_consensus_seqs_json = "results/compare_natural/pango-consensus-sequences_summary.json"
pango_dms_phenotypes_csv = "results/compare_natural/pango_dms_phenotypes.csv"
pango_by_date_html = "results/compare_natural/pango_dms_phenotypes_by_date.html"
pango_affinity_vs_escape_html = "results/compare_natural/pango_affinity_vs_escape.html"
pango_dms_vs_growth_regression_html = "results/compare_natural/pango_dms_vs_growth_regression.html"
pango_dms_vs_growth_regression_by_domain_html = "results/compare_natural/pango_dms_vs_growth_regression_by_domain.html"
pango_dms_vs_growth_corr_html = "results/compare_natural/pango_dms_vs_growth_corr.html"
pango_dms_vs_growth_corr_by_domain_html = "results/compare_natural/pango_dms_vs_growth_corr_by_domain.html"
n_random = 200
exclude_clades = []
exclude_clades_with_muts = []
In [3]:
import collections
import itertools
import json
import math
import re

import altair as alt

import numpy

import pandas as pd

import polyclonal.plot

import scipy.stats

import statsmodels.api

# allow altair to plot data frames with more than its default 5000-row limit
_ = alt.data_transformers.disable_max_rows()

Read Pango clades and mutations¶

In [4]:
# load per-clade Pango consensus-sequence summary keyed by clade name
with open(pango_consensus_seqs_json) as f:
    pango_clades = json.load(f)

def n_child_clades(c, clades=None):
    """Recursively count all descendant clades of a Pango clade.

    Parameters
    ----------
    c : str
        Name of the Pango clade.
    clades : dict or None
        Mapping of clade name to a record with a "children" list of clade
        names. Defaults to the module-level ``pango_clades`` (the original
        behavior); passing it explicitly makes the function testable in
        isolation.

    Returns
    -------
    int
        Total number of descendants (children, grandchildren, ...).
    """
    if clades is None:
        clades = pango_clades
    direct_children = clades[c]["children"]
    # direct children plus all deeper descendants; generator avoids a temp list
    return len(direct_children) + sum(
        n_child_clades(c_child, clades) for c_child in direct_children
    )

def build_records(c, recs):
    """Recursively fill ``recs`` with information on clade ``c`` and its descendants.

    ``recs`` maps column names to lists; one entry per clade is appended to
    each list. Clades already recorded are skipped so lineages reached via
    multiple starting clades are only added once.
    """
    if c in recs["clade"]:
        return  # already recorded via another path
    clade_info = pango_clades[c]
    # collect spike amino-acid substitutions and deletions, stripping the "S:" prefix
    spike_muts = []
    for field in ["aaSubstitutions", "aaDeletions"]:
        for mut in clade_info[field]:
            if mut.startswith("S:"):
                spike_muts.append(mut.split(":")[1])
    recs["clade"].append(c)
    recs["n_child_clades"].append(n_child_clades(c))
    recs["date"].append(clade_info["designationDate"])
    recs["muts_from_ref"].append(spike_muts)
    # recurse into the direct children of this clade
    for c_child in clade_info["children"]:
        build_records(c_child, recs)
        
# build a data frame with one row per clade descended from the starting clades
records = collections.defaultdict(list)
for starting_clade in starting_clades:
    build_records(starting_clade, records)

pango_df = pd.DataFrame(records).query("clade not in @exclude_clades")
# spike mutations separating the DMS clade from the reference sequence
dms_clade_mutations_from_ref = pango_df.set_index("clade").at[
    dms_clade, "muts_from_ref"
]

def mutations_from(muts, from_muts):
    """Get mutations of ``muts`` relative to the ``from_muts`` sequence.

    Both arguments are collections of mutation strings relative to a common
    reference (e.g. "A83V"; "-" denotes a deletion). Returns the list of
    mutations, sorted by site, needed to convert the ``from_muts`` sequence
    into the ``muts`` sequence.
    """
    new_muts = set(muts).symmetric_difference(from_muts)
    # raw string: "\d" / "\-" in a non-raw string raise SyntaxWarning
    # (invalid escape sequence) on modern Python
    assert all(re.fullmatch(r"[A-Z\-]\d+[A-Z\-]", m) for m in new_muts)
    # group the differing mutations by site
    new_muts_d = collections.defaultdict(list)
    for m in new_muts:
        new_muts_d[int(m[1: -1])].append(m)
    new_muts_list = []
    for _, ms in sorted(new_muts_d.items()):
        if len(ms) == 1:
            # only one of the two sequences is mutated at this site
            m = ms[0]
            if m in muts:
                new_muts_list.append(m)
            else:
                # mutation only in from_muts: invert it (swap parent and mutant)
                assert m in from_muts
                new_muts_list.append(m[-1] + m[1: -1] + m[0])
        else:
            # both sequences mutated at this site to different residues:
            # express as a mutation from the from_muts residue to the muts one
            m, from_m = ms
            if m not in muts:
                from_m, m = m, from_m
            assert m in muts and from_m in from_muts
            new_muts_list.append(from_m[-1] + m[1: ])
    return new_muts_list

# re-express each clade's mutations relative to the DMS clade rather than the reference
pango_df = (
    pango_df
    .assign(
        muts_from_dms_clade=lambda x: x["muts_from_ref"].apply(
            mutations_from, args=(dms_clade_mutations_from_ref,),
        ),
        date=lambda x: pd.to_datetime(x["date"]),
    )
    .drop(columns="muts_from_ref")
    .sort_values("date")
    .reset_index(drop=True)
)

# drop any clade carrying one of the explicitly excluded mutations
for mut in exclude_clades_with_muts:
    pango_df = pango_df[pango_df["muts_from_dms_clade"].map(lambda ms: mut not in ms)]

pango_df
Out[4]:
clade n_child_clades date muts_from_dms_clade
0 BA.2 384 2021-12-07 [A83V, -144Y, Q146H, E183Q, E213G, V252G, H339...
1 BA.2.1 0 2022-02-25 [A83V, -144Y, Q146H, E183Q, E213G, V252G, H339...
2 BA.2.2 1 2022-02-25 [A83V, -144Y, Q146H, E183Q, E213G, V252G, H339...
3 BA.2.3 52 2022-02-25 [A83V, -144Y, Q146H, E183Q, E213G, V252G, H339...
4 BA.2.7 0 2022-03-25 [A83V, -144Y, Q146H, E183Q, E213G, V252G, H339...
... ... ... ... ...
1575 GJ.1.2.7 0 2023-10-04 [K182N, V252G, D253G, K478R, P521S]
1576 GJ.1.2.8 0 2023-10-04 [K182N, V252G, D253G, P521S, T747I]
1577 GJ.5 1 2023-10-04 [K182N, V252G, D253G, K478R, P521S]
1578 GJ.5.1 0 2023-10-04 [K182N, V252G, D253G, S255F, K478R, P521S]
1579 JD.2 0 2023-10-04 []

1580 rows × 4 columns

Assign DMS phenotypes to Pango clades¶

First define function that assigns DMS phenotypes to mutations:

In [5]:
# read the DMS data, renaming measurement columns to shorter display names
dms_summary = pd.read_csv(dms_summary_csv).rename(
    columns={
        "spike mediated entry": "cell entry",
        "human sera escape": "sera escape",
    }
)

# specify DMS phenotypes of interest
phenotypes = [
    "sera escape",
    "ACE2 affinity",
    "cell entry",
]
assert set(phenotypes).issubset(dms_summary.columns)

# colors used consistently for each phenotype in all plots below
phenotype_colors = {
    "sera escape": "red",
    "ACE2 affinity": "blue",
    "cell entry": "purple",
}
assert set(phenotypes) == set(phenotype_colors)


# dict that maps site to wildtype in DMS
dms_wt = dms_summary.set_index("site")["wildtype"].to_dict()

# dict that maps site to region in DMS
site_to_region = dms_summary.set_index("site")["region"].to_dict()

def mut_dms(m, dms_data):
    """Return the DMS phenotype effects of mutation string ``m``.

    ``dms_data`` maps ``(site, parent, mutant)`` tuples to dicts of phenotype
    values measured relative to the DMS wildtype. Returns a dict with one
    entry per phenotype plus an ``is_RBD`` flag; entries are ``pd.NA`` when
    the mutation (or its site) lacks DMS measurements.
    """
    null_d = {k: pd.NA for k in phenotypes}
    if pd.isnull(m) or int(m[1: -1]) not in dms_wt:
        # mutation is null, or its site is absent from the DMS data
        d = null_d
        d["is_RBD"] = pd.NA
        assert list(d) == phenotypes + ["is_RBD"]
        return d
    parent_aa = m[0]
    site = int(m[1: -1])
    mut_aa = m[-1]
    wt = dms_wt[site]
    if parent_aa == wt:
        # direct measurement relative to the DMS wildtype
        d = dms_data.get((site, parent_aa, mut_aa), null_d)
    elif mut_aa == wt:
        # reversion to wildtype: negate the forward measurement
        measured = dms_data.get((site, mut_aa, parent_aa))
        d = null_d if measured is None else {k: -v for (k, v) in measured.items()}
    else:
        # neither residue is wildtype: difference of the two wildtype-relative effects
        try:
            parent_d = dms_data[(site, wt, parent_aa)]
            mut_d = dms_data[(site, wt, mut_aa)]
            d = {p: mut_d[p] - parent_d[p] for p in phenotypes}
        except KeyError:
            d = null_d
    d["is_RBD"] = (site_to_region[site] == "RBD")
    assert list(d) == phenotypes + ["is_RBD"]
    return d

Now assign phenotypes to pango clades. We do this both using the actual DMS data and randomizing the DMS data among measured mutations:

In [6]:
def get_pango_dms_df(dms_data_dict):
    """Given dict mapping mutations to DMS data, get data frame of values for Pango clades.

    Returns one row per (clade, DMS phenotype) with mutation effects summed
    over the full spike, the RBD only, and outside the RBD. Mutations without
    DMS measurements contribute an effect of 0 and are listed in the
    "..._missing_data" column.
    """
    pango_dms_df = (
        pango_df
        # put one mutation in each column
        .explode("muts_from_dms_clade")
        .rename(columns={"muts_from_dms_clade": "mutation"})
        # to add multiple columns: https://stackoverflow.com/a/46814360
        .apply(
            lambda cols: pd.concat([cols, pd.Series(mut_dms(cols["mutation"], dms_data_dict))]),
            axis=1,
        )
        # long form: one row per clade / mutation / phenotype
        .melt(
            id_vars=["clade", "date", "n_child_clades", "mutation", "is_RBD"],
            value_vars=phenotypes,
            var_name="DMS_phenotype",
            value_name="mutation_effect",
        )
        .assign(
            muts_from_dms_clade=lambda x: x.groupby(["clade", "DMS_phenotype"])["mutation"].transform(
                lambda ms: "; ".join([m for m in ms if not pd.isnull(m)])
            ),
            # mutations present in the clade but lacking a DMS measurement
            mutation_missing=lambda x: x["mutation"].where(
                x["mutation_effect"].isnull() & x["mutation"].notnull(),
                pd.NA,
            ),
            muts_from_dms_clade_missing_data=lambda x: (
                x.groupby(["clade", "DMS_phenotype"])["mutation_missing"]
                .transform(lambda ms: "; ".join([m for m in ms if not pd.isnull(m)]))
            ),
            # unmeasured mutations are assumed to have no effect
            mutation_effect=lambda x: x["mutation_effect"].fillna(0),
            is_RBD=lambda x: x["is_RBD"].fillna(False),
            mutation_effect_RBD=lambda x: x["mutation_effect"] * x["is_RBD"].astype(int),
            mutation_effect_nonRBD=lambda x: x["mutation_effect"] * (~x["is_RBD"]).astype(int),
        )
        # sum mutation effects within each clade / phenotype
        .groupby(
            [
                "clade",
                "date",
                "n_child_clades",
                "muts_from_dms_clade",
                "muts_from_dms_clade_missing_data",
                "DMS_phenotype",
            ],
            as_index=False,
        )
        .aggregate(
            phenotype=pd.NamedAgg("mutation_effect", "sum"),
            phenotype_RBD_only=pd.NamedAgg("mutation_effect_RBD", "sum"),
            phenotype_nonRBD_only=pd.NamedAgg("mutation_effect_nonRBD", "sum"),
        )
        .rename(
            columns={
                "muts_from_dms_clade": f"muts_from_{dms_clade}",
                "muts_from_dms_clade_missing_data": f"muts_from_{dms_clade}_missing_data",
            },
        )
        .sort_values(["date", "DMS_phenotype"])
        .reset_index(drop=True)
    )

    # sanity checks: no clades lost, and domain effects sum to the total
    assert set(pango_df["clade"]) == set(pango_dms_df["clade"])
    assert numpy.allclose(
        pango_dms_df["phenotype"],
        pango_dms_df["phenotype_RBD_only"] + pango_dms_df["phenotype_nonRBD_only"]
    )

    return pango_dms_df

# First, get the actual DMS data mapped to phenotype
dms_data_dict_actual = (
    dms_summary
    .set_index(["site", "wildtype", "mutant"])
    [phenotypes]
    .to_dict(orient="index")
)
pango_dms_df = get_pango_dms_df(dms_data_dict_actual)
print(f"Saving Pango DMS phenotypes to {pango_dms_phenotypes_csv}")
pango_dms_df.to_csv(pango_dms_phenotypes_csv, float_format="%.4f", index=False)

# Now get the randomized DMS data mapped to phenotype
pango_dms_dfs_rand = []
numpy.random.seed(0)  # seed for reproducible randomizations
for irandom in range(1, n_random + 1):
    # randomize the non-null DMS data for each phenotype
    dms_summary_rand = dms_summary.copy()
    for phenotype in phenotypes:
        # assign is called immediately, so the loop variable binds correctly
        dms_summary_rand = dms_summary_rand.assign(
            **{phenotype: lambda x: numpy.random.permutation(x[phenotype].values)}
        )
    dms_data_dict_rand = (
        dms_summary_rand
        .set_index(["site", "wildtype", "mutant"])
        [phenotypes]
        .to_dict(orient="index")
    )
    pango_dms_dfs_rand.append(get_pango_dms_df(dms_data_dict_rand).assign(randomize=irandom))
# all randomizations concatenated
pango_dms_df_rand = pd.concat(pango_dms_dfs_rand)
Saving Pango DMS phenotypes to results/compare_natural/pango_dms_phenotypes.csv

Plot phenotypes of Pango clades¶

Plot phenotypes of Pango clades versus their designation dates:

In [7]:
# map aggregate phenotype columns to human-readable spike-region labels
region_cols = {
    "phenotype": "full spike",
    "phenotype_RBD_only": "RBD only",
    "phenotype_nonRBD_only": "non-RBD only",
}

# long-form data frame: one row per clade / phenotype / spike region
pango_chart_df = (
    pango_dms_df
    .melt(
        id_vars=[c for c in pango_dms_df if c not in region_cols],
        value_vars=region_cols,
        var_name="spike_region",
        value_name="phenotype value",
    )
    .assign(
        spike_region=lambda x: x["spike_region"].map(region_cols),
    )
    .rename(columns={f"muts_from_{dms_clade}_missing_data": "muts_missing_data"})
)

# columns cannot have "." in them for Altair
col_renames = {c: c.replace(".", "_") for c in pango_chart_df.columns if "." in c}
col_renames_rev = {v: k for (k, v) in col_renames.items()}
pango_chart_df = pango_chart_df.rename(columns=col_renames)

# selection to highlight one clade across all panels on mouseover
clade_selection = alt.selection_point(fields=["clade"], on="mouseover", empty=False)

# shared encodings (tooltips, highlight styling, phenotype colors) for all panels
base_pango_chart = (
    alt.Chart(pango_chart_df)
    .encode(
        tooltip=[
            alt.Tooltip(c, title=col_renames_rev[c] if c in col_renames_rev else c)
            for c in pango_chart_df.columns
        ],
        opacity=alt.condition(clade_selection, alt.value(1), alt.value(0.35)),
        size=alt.condition(clade_selection, alt.value(60), alt.value(40)),
        strokeWidth=alt.condition(clade_selection, alt.value(2), alt.value(0)),
        color=alt.Color(
            "DMS_phenotype",
            legend=None,
            scale=alt.Scale(
                range=list(phenotype_colors.values()),
                domain=list(phenotype_colors.keys()),
            ),
        ),
    )
    .mark_circle(stroke="black")
    .properties(width=300, height=125)
)

# one row of panels per phenotype; only the last row gets an x-axis and only
# the first row gets the spike-region column headers
phenotype_pango_charts = []
for phenotype in phenotypes:
    first_row = (phenotype == phenotypes[0])
    last_row = (phenotype == phenotypes[-1])
    phenotype_pango_charts.append(
        base_pango_chart
        .transform_filter(alt.datum["DMS_phenotype"] == phenotype)
        .encode(
            x=alt.X(
                "date",
                title="designation date of clade" if last_row else None,
                axis=(
                    alt.Axis(titleFontSize=12, labelOverlap=True, format="%b-%Y", labelAngle=-90)
                    if last_row
                    else None
                ),
                scale=alt.Scale(nice=False, padding=3),
            ),
            y=alt.Y(
                "phenotype value",
                title=phenotype,
                axis=alt.Axis(titleFontSize=12),
                scale=alt.Scale(nice=False, padding=3),
            ),
            column=alt.Column(
                "spike_region",
                sort=list(region_cols),
                title=None,
                header=(
                    alt.Header(labelFontSize=12, labelFontStyle="bold", labelPadding=4)
                    if first_row
                    else None
                ),
                spacing=4,
            ),
        )
    )

# stack the per-phenotype rows vertically and add the overall title
pango_chart = (
    alt.vconcat(*phenotype_pango_charts, spacing=4)
    .configure_axis(grid=False)
    .add_params(clade_selection)
    .properties(
        title=alt.TitleParams(
            f"DMS predicted phenotypes of Pango clades descended from {', '.join(starting_clades)}",
            anchor="middle",
            fontSize=16,
            dy=-5,
        ),
    )
)

print(f"Saving chart to {pango_by_date_html}")
pango_chart.save(pango_by_date_html)

pango_chart
Saving chart to results/compare_natural/pango_dms_phenotypes_by_date.html
Out[7]:

Pango clade affinity versus escape scatter plot¶

In [8]:
# pivot phenotypes to wide form so affinity and escape are columns of one row per clade
pango_scatter_df = (
    pango_dms_df
    .pivot_table(
        index=[
            c
            for c in pango_dms_df
            if c not in {"DMS_phenotype", "phenotype", "phenotype_RBD_only", "phenotype_nonRBD_only"}
        ],
        values="phenotype",
        columns="DMS_phenotype",
    )
    .reset_index()
    .rename(columns={f"muts_from_{dms_clade}_missing_data": "muts_missing_data"})
    .rename(columns=col_renames)
)

pango_scatter_chart = (
    alt.Chart(pango_scatter_df)
    .encode(
        x=alt.X(
            "ACE2 affinity",
            axis=alt.Axis(titleFontSize=12),
            scale=alt.Scale(nice=False, padding=5),
        ),
        y=alt.Y(
            "sera escape",
            axis=alt.Axis(titleFontSize=12),
            scale=alt.Scale(nice=False, padding=5),
        ),
        tooltip=[
            alt.Tooltip(c, title=col_renames_rev[c] if c in col_renames_rev else c)
            for c in pango_scatter_df.columns
        ],
        opacity=alt.condition(clade_selection, alt.value(1), alt.value(0.35)),
        size=alt.condition(clade_selection, alt.value(100), alt.value(55)),
        strokeWidth=alt.condition(clade_selection, alt.value(2), alt.value(0)),
    )
    .mark_circle(stroke="red", color="black")
    .add_params(clade_selection)
    .configure_axis(grid=False)
    .properties(
        title=alt.TitleParams(
            [
                "DMS predicted ACE2 affinity vs serum escape",
                # bug fix: this previously interpolated `starting_clade`, a
                # leftover loop variable holding only the LAST starting clade;
                # list all starting clades as the chart above does
                f"for Pango clades descended from {', '.join(starting_clades)}"
            ],
            anchor="middle",
            fontSize=14,
            dy=-5,
        ),
        width=300,
        height=300,
    )
)

print(f"Saving chart to {pango_affinity_vs_escape_html}")
pango_scatter_chart.save(pango_affinity_vs_escape_html)

pango_scatter_chart
Saving chart to results/compare_natural/pango_affinity_vs_escape.html
Out[8]:

Correlate with clade growth¶

In [9]:
# read clade growth rates estimated by multinomial logistic regression
growth_rates = pd.read_csv(growth_rates_csv).rename(
    columns={"pango": "clade", "seq_volume": "number sequences"}
)

# sanity check: every clade with a growth estimate must be a known Pango clade
if (invalid_clades := set(growth_rates["clade"]) - set(pango_clades)):
    raise ValueError(f"Growth rates specified for {invalid_clades}")

# inner merges: only clades with both growth and DMS estimates are kept
pango_dms_growth_df = pango_dms_df.merge(growth_rates, on="clade", validate="many_to_one")

pango_dms_growth_df_rand = pango_dms_df_rand.merge(growth_rates, on="clade", validate="many_to_one")

print(
    f"{growth_rates['clade'].nunique()} clades have growth rates estimates.\n"
    f"{pango_dms_df['clade'].nunique()} clades have DMS estimates.\n"
    f"{pango_dms_growth_df['clade'].nunique()} clades have growth and DMS estimates"
)

# correlation of each phenotype (per domain) with the growth rate R
print("Simple correlations:")
display(
    pango_dms_growth_df
    .groupby("DMS_phenotype")
    [["R", "phenotype", "phenotype_RBD_only", "phenotype_nonRBD_only"]]
    .corr()
    [["R"]]
)
990 clades have growth rates estimates.
1580 clades have DMS estimates.
923 clades have growth and DMS estimates
Simple correlations:
R
DMS_phenotype
ACE2 affinity R 1.000000
phenotype -0.488450
phenotype_RBD_only -0.309932
phenotype_nonRBD_only -0.331831
cell entry R 1.000000
phenotype 0.788332
phenotype_RBD_only 0.812665
phenotype_nonRBD_only 0.454013
sera escape R 1.000000
phenotype 0.930667
phenotype_RBD_only 0.927965
phenotype_nonRBD_only 0.359769

Plot clade growth rate (R) versus designation date, with point sizes proportional (on a log scale) to the number of sequences in each clade:

In [10]:
# growth rate (R) vs designation date; point size = number of sequences (log scale)
(
    alt.Chart(pango_dms_growth_df)
    .encode(
        x="date",
        y="R",
        size=alt.Size("number sequences", scale=alt.Scale(type="log")),
        tooltip=pango_dms_growth_df.columns.tolist(),
    )
    .mark_circle(opacity=0.25, color="black")
)
Out[10]:

Now perform weighted least-squares (WLS) regression of clade growth on the DMS phenotypes, weighting clades by log number of sequences:

In [11]:
# pivot DMS data to get phenotypes
def pivot_for_ols_vars(df):
    """Pivot long-form DMS / growth data into one row per clade for regression.

    ``df`` must have columns including "clade", "R", "date", "DMS_phenotype",
    "number sequences", and the per-domain phenotype columns. Returns a frame
    with one column per (phenotype, domain) pair, e.g. "sera escape (RBD)".
    """
    ols_vars = (
        df
        .rename(
            columns={
                "phenotype": "full spike",
                "phenotype_RBD_only": "RBD",
                "phenotype_nonRBD_only": "non RBD",
            }
        )
        .assign(
            # group muts missing data from all phenotypes
            muts_from_DMS_clade_missing_data=lambda x: (
                x.groupby("clade")
                [f"muts_from_{dms_clade}_missing_data"]
                .transform(
                    # dict.fromkeys de-duplicates while preserving order
                    lambda s: "; ".join(dict.fromkeys([m for ms in s.str.split("; ") for m in ms if m]))
                )
            ),
        )
        .rename(columns={f"muts_from_{dms_clade}": "muts_from_DMS_clade"})
        .pivot_table(
            index=[
                "clade",
                "R",
                "date",
                "muts_from_DMS_clade",
                "muts_from_DMS_clade_missing_data",
                "number sequences",
            ],
            columns="DMS_phenotype",
            values=["full spike", "RBD", "non RBD"],
        )
    )
    # flatten column names
    assert all(len(c) == 2 for c in ols_vars.columns.values)
    ols_vars.columns = [f"{pheno} ({domain})" for domain, pheno in ols_vars.columns.values]
    return ols_vars.reset_index()

ols_vars = pivot_for_ols_vars(pango_dms_growth_df)

# fit once with full-spike phenotypes and once with RBD / non-RBD split out
# https://www.einblick.ai/python-code-examples/ordinary-least-squares-regression-statsmodels/
for name, exog_vars, regression_chartfile, corr_chartfile in [
    (
        "full spike",
        [f"{c} (full spike)" for c in phenotypes],
        pango_dms_vs_growth_regression_html,
        pango_dms_vs_growth_corr_html
    ),
    (
        "separate RBD and non-RBD",
        [f"{c} ({d})" for d in ["RBD", "non RBD"] for c in phenotypes],
        pango_dms_vs_growth_regression_by_domain_html,
        pango_dms_vs_growth_corr_by_domain_html,
    ),
]:
    print(f"\n\nFitting for {name}:")
    # weighted least squares of growth rate (R) on the DMS phenotypes
    ols_model = statsmodels.api.WLS(
        endog=ols_vars[["R"]],
        exog=statsmodels.api.add_constant(ols_vars[exog_vars]),
        # weight by log n sequences, so pass log**2
        weights=numpy.log(ols_vars["number sequences"])**2,
    )
    res_ols = ols_model.fit()
    display(res_ols.summary())

    fitted_df = ols_vars.assign(DMS_predicted_growth=res_ols.predict())

    plot_size=180

    # mouseover selection to highlight one clade across all panels
    clade_selection = alt.selection_point(fields=["clade"], on="mouseover", empty=False)

    # slider to filter clades by (log10) number of sequences
    n_sequences_init = int(10 * math.log10(fitted_df["number sequences"].min())) / 10
    n_sequences_slider = alt.param(
        value=n_sequences_init,
        bind=alt.binding_range(
            name="minimum log10 number sequences in clade",
            min=n_sequences_init,
            max=math.log10(fitted_df["number sequences"].max() / 10),
        ),
    )

    # date slider: https://stackoverflow.com/a/67941109
    select_date = alt.selection_interval(encodings=["x"])
    date_slider = (
        alt.Chart(fitted_df[["clade", "date"]].drop_duplicates())
        .mark_bar(color="black")
        .encode(
            x=alt.X(
                "date",
                title="zoom bar to select clades by designation date",
                axis=alt.Axis(format="%b-%Y"),
            ),
            y=alt.Y("count()", title=["number", "clades"]),
        )
        .properties(width=1.5 * plot_size, height=45)
        .add_params(select_date)
    )

    # shared encodings (filters, sizes, tooltips) for all scatter panels
    base_growth_chart = (
        alt.Chart(fitted_df)
        .transform_filter(
            # Vega expressions lack log10; divide natural log by ln(10)
            alt.expr.log(alt.datum["number sequences"]) / math.log(10) >= n_sequences_slider
        )
        .transform_filter(select_date)
        .encode(
            size=alt.Size(
                "number sequences",
                scale=alt.Scale(
                    type="log",
                    nice=False,
                    range=[15, 250],
                ),
                legend=alt.Legend(symbolStrokeWidth=0, symbolFillColor="gray"),
            ),
            strokeWidth=alt.condition(clade_selection, alt.value(2), alt.value(0.5)),
            strokeOpacity=alt.condition(clade_selection, alt.value(1), alt.value(0.5)),
            tooltip=[
                "clade",
                alt.Tooltip("R", title="growth rate (R)", format=".1f"),
                alt.Tooltip("DMS_predicted_growth", title="DMS predicted growth", format=".1f"),
                alt.Tooltip("number sequences", format=".2g"),
                alt.Tooltip("date", title="designation date"),
                alt.Tooltip("muts_from_DMS_clade", title=f"muts from {dms_clade}"),
                alt.Tooltip("muts_from_DMS_clade_missing_data", title="muts missing DMS data"),
                *[alt.Tooltip(v, format=".2f") for v in exog_vars],
            ],
        )
        .properties(width=plot_size, height=plot_size)
        .add_params(clade_selection, n_sequences_slider)
    )

    # one regression panel and one simple-correlation panel per exogenous variable
    growth_charts = []
    simple_corr_charts = []
    for i, (dms_pheno, pheno) in enumerate(zip(
        exog_vars,
        itertools.cycle(phenotypes)
    )):
        assert dms_pheno.startswith(pheno)
        base_pheno_chart = (
            base_growth_chart
            .encode(
                y=alt.Y(
                    "R",
                    title="actual clade growth rate (R)",
                    scale=alt.Scale(nice=False, padding=5, zero=False),
                    # only the first panel in each row shows the y-axis
                    axis=None if i % len(phenotypes) else alt.Axis(),
                ),
            )
        )

        growth_charts.append(
            base_pheno_chart
            .encode(
                x=alt.X(
                    "DMS_predicted_growth",
                    title="DMS predicted clade growth",
                    scale=alt.Scale(nice=False, padding=5, zero=False),
                ),
                color=alt.Color(
                    dms_pheno,
                    title=None,
                    legend=alt.Legend(
                        orient="top",
                        titleFontSize=12,
                        gradientLength=plot_size,
                        gradientThickness=10,
                        offset=5,
                        tickCount=3,
                    ),
                    scale=alt.Scale(
                        range=polyclonal.plot.color_gradient_hex("lightgray", phenotype_colors[pheno], 40),
                        nice=False,
                    ),
                ),
            )
            .mark_circle(stroke="black", fillOpacity=0.6)
            .properties(
                title=alt.TitleParams(
                    text=dms_pheno,
                    subtitle=(
                        f"coefficient: {res_ols.params[dms_pheno]:.1f} "
                        # https://stackoverflow.com/a/53966201
                        + f"\u00B1 {res_ols.bse[dms_pheno]:.1f}, "
                        + f"P: {res_ols.pvalues[dms_pheno]:.1g}"
                    ),
                    subtitleFontSize=11,
                ),
            )
        )

        # pheno_p is unpacked but intentionally unused (randomization P below)
        pheno_r, pheno_p = scipy.stats.pearsonr(fitted_df["R"], fitted_df[dms_pheno])
        simple_corr_charts.append(
            base_pheno_chart
            .transform_calculate(color_phenotype=f"'{pheno}'")
            .encode(
                x=alt.X(
                    dms_pheno,
                    scale=alt.Scale(nice=False, padding=5, zero=False),
                ),
                color=alt.Color(
                    "color_phenotype:N",
                    scale=alt.Scale(
                        range=list(phenotype_colors.values()),
                        domain=list(phenotype_colors.keys()),
                    ),
                    legend=None,
                ),
            )
            .mark_circle(stroke="black", fillOpacity=0.3, color=phenotype_colors[pheno])
            .properties(
                title=alt.TitleParams(
                    text=dms_pheno,
                    subtitle=f"Pearson r: {pheno_r:.2f}",
                    subtitleFontSize=11,
                ),
            )
        )

    # overall fit quality: r is the square root of the WLS R-squared
    actual_r = math.sqrt(res_ols.rsquared)
    assert len(growth_charts) % len(phenotypes) == 0
    growth_chart = (
        alt.vconcat(
            alt.vconcat(
                *[
                    alt.hconcat(
                        *growth_charts[i * len(phenotypes): (i + 1) * len(phenotypes)], spacing=13
                    ).resolve_scale(color="independent")
                    for i in range(len(growth_charts) // len(phenotypes))
                ],
                spacing=13,
            ),
            date_slider,
        )
        .properties(
            title=alt.TitleParams(
                f"Weighted linear regression of DMS phenotypes vs clade growth (Pearson r = {actual_r:.2f})",
                anchor="middle",
                fontSize=14,
                dy=-5,
            ),
        )
        .configure_axis(grid=False)
    )

    simple_corr_chart = (
        alt.vconcat(
            alt.vconcat(
                *[
                    alt.hconcat(
                        *simple_corr_charts[i * len(phenotypes): (i + 1) * len(phenotypes)], spacing=13
                    )
                    for i in range(len(simple_corr_charts) // len(phenotypes))
                ],
                spacing=13,
            ),
            date_slider,
        )
        .properties(
            title=alt.TitleParams(
                "Simple correlations of DMS phenotypes vs clade growth",
                anchor="middle",
                fontSize=14,
                dy=-5,
            ),
        )
        .configure_axis(grid=False)
    )

    display(growth_chart)
    print(f"Saving to {regression_chartfile}")
    growth_chart.save(regression_chartfile)

    display(simple_corr_chart)
    print(f"Saving to {corr_chartfile}")
    simple_corr_chart.save(corr_chartfile)

    # fit randomized models and compute P-value based on R values
    print("Computing P-value from randomizations")
    rand_r = []
    for randomseed, rand_df in pango_dms_growth_df_rand.groupby("randomize"):
        rand_ols_vars = pivot_for_ols_vars(rand_df)
        rand_ols_model = statsmodels.api.WLS(
            endog=rand_ols_vars[["R"]],
            exog=statsmodels.api.add_constant(rand_ols_vars[exog_vars]),
            # weight by log n sequences, so pass log**2
            weights=numpy.log(rand_ols_vars["number sequences"])**2,
        )
        rand_res_ols = rand_ols_model.fit()
        rand_r.append(math.sqrt(rand_res_ols.rsquared))
    # P-value: fraction of randomizations with r at least as large as actual
    n_rand_ge = sum(r >= actual_r for r in rand_r)
    pval = f"= {n_rand_ge / len(rand_r)}" if n_rand_ge else f"< {1 / len(rand_r)}"

    rand_r_hist = (
        alt.Chart(pd.DataFrame({"r": rand_r}))
        .encode(
            x=alt.X(
                "r",
                title="Pearson r",
                bin=alt.BinParams(step=0.02, extent=(0, 1)),
                scale=alt.Scale(domain=(0, 1)),
                axis=alt.Axis(values=[0, 0.2, 0.4, 0.6, 0.8, 1]),
            ),
            y=alt.Y("count()", title="number of randomizations"),
        )
        .mark_bar(color="black", opacity=0.65, align="right")
        .properties(width=250, height=130)
    )

    # vertical rule at the actual (non-randomized) r
    actual_r_line = (
        alt.Chart(pd.DataFrame({"r": [actual_r]}))
        .encode(x="r")
        .mark_rule(size=2, color="red", strokeDash=[4, 2])
    )

    pval_chart = (
        (rand_r_hist + actual_r_line)
        .configure_axis(grid=False)
        .properties(
            title=alt.TitleParams(
                f"P {pval}",
                subtitle=f"{n_rand_ge} of {len(rand_r)} randomizations 	\u2265 actual r of {actual_r:.2f}",
            ),
        )
    )

    display(pval_chart)

Fitting for full spike:
WLS Regression Results
Dep. Variable: R R-squared: 0.883
Model: WLS Adj. R-squared: 0.883
Method: Least Squares F-statistic: 2309.
Date: Sun, 08 Oct 2023 Prob (F-statistic): 0.00
Time: 13:06:37 Log-Likelihood: -3613.1
No. Observations: 923 AIC: 7234.
Df Residuals: 919 BIC: 7254.
Df Model: 3
Covariance Type: nonrobust
coef std err t P>|t| [0.025 0.975]
const 32.6462 0.730 44.692 0.000 31.213 34.080
sera escape (full spike) 24.3328 0.570 42.684 0.000 23.214 25.452
ACE2 affinity (full spike) 4.1776 1.260 3.314 0.001 1.704 6.651
cell entry (full spike) 13.4457 2.225 6.043 0.000 9.079 17.812
Omnibus: 27.151 Durbin-Watson: 0.816
Prob(Omnibus): 0.000 Jarque-Bera (JB): 35.388
Skew: 0.315 Prob(JB): 2.07e-08
Kurtosis: 3.724 Cond. No. 12.5


Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Saving to results/compare_natural/pango_dms_vs_growth_regression.html
Saving to results/compare_natural/pango_dms_vs_growth_corr.html
Computing P-value from randomizations

Fitting for separate RBD and non-RBD:
WLS Regression Results
Dep. Variable: R R-squared: 0.891
Model: WLS Adj. R-squared: 0.890
Method: Least Squares F-statistic: 1243.
Date: Sun, 08 Oct 2023 Prob (F-statistic): 0.00
Time: 13:07:13 Log-Likelihood: -3581.5
No. Observations: 923 AIC: 7177.
Df Residuals: 916 BIC: 7211.
Df Model: 6
Covariance Type: nonrobust
coef std err t P>|t| [0.025 0.975]
const 33.5352 0.731 45.883 0.000 32.101 34.970
sera escape (RBD) 29.3391 0.848 34.608 0.000 27.675 31.003
ACE2 affinity (RBD) 3.9192 1.357 2.888 0.004 1.256 6.582
cell entry (RBD) -18.1338 4.472 -4.055 0.000 -26.909 -9.358
sera escape (non RBD) 40.3985 4.835 8.355 0.000 30.909 49.888
ACE2 affinity (non RBD) 9.9229 1.996 4.972 0.000 6.006 13.840
cell entry (non RBD) 23.1634 2.956 7.836 0.000 17.362 28.965
Omnibus: 44.555 Durbin-Watson: 0.908
Prob(Omnibus): 0.000 Jarque-Bera (JB): 92.936
Skew: 0.297 Prob(JB): 6.59e-21
Kurtosis: 4.436 Cond. No. 27.5


Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.
Saving to results/compare_natural/pango_dms_vs_growth_regression_by_domain.html
Saving to results/compare_natural/pango_dms_vs_growth_corr_by_domain.html
Computing P-value from randomizations

Distributions of DMS mutation effects in clades with growth estimates versus all mutations¶

In [12]:
# count occurrences of each mutation among clades with growth estimates
# NOTE(review): pango_dms_growth_df has one row per clade per phenotype, so
# each clade contributes its mutations multiple times; counts are scaled by
# the number of phenotypes -- confirm this constant factor is intended
muts_in_clades = collections.Counter(
    pango_dms_growth_df
    [f"muts_from_{dms_clade}"]
    .pipe(lambda s: s[s != ""])
    .str.split("; ")
    .explode()
)
print(f"There are {len(muts_in_clades)} mutations found in clades with growth estimates")

# long-form DMS data for all measured mutations, labeled RBD / non RBD
all_muts_dms = (
    dms_summary
    .query("wildtype != mutant")
    .assign(mutation=lambda x: x["wildtype"] + x["site"].astype(str) + x["mutant"])
    .assign(region=lambda x: x["region"].where(x["region"] == "RBD", "non RBD"))
    .melt(
        id_vars=["mutation", "region"],
        value_vars=phenotypes,
        var_name="DMS_phenotype",
        value_name="phenotype",
    )
    .query("phenotype.notnull()")
)

# stack "any" mutations with those seen in Pango clades (weighted by occurrence)
all_muts_dms = pd.concat(
    [
        all_muts_dms.assign(mutation_type="any", count=1),
        all_muts_dms.query("mutation in @muts_in_clades").assign(
            mutation_type="in Pango clade",
            count=lambda x: x["mutation"].map(muts_in_clades),
        ),
    ]
)

# histogram of mutation effects for each phenotype, faceted by mutation type
for pheno in phenotypes:

    base_hist = (
        alt.Chart(
            all_muts_dms
            .query("DMS_phenotype == @pheno")
            .drop(columns=["DMS_phenotype", "mutation"])
        )
        .encode(
            x=alt.X("phenotype", bin=alt.BinParams(maxbins=50)),
            y=alt.Y("sum(count)", title="mutations"),
            color=alt.value(phenotype_colors[pheno]),
            row=alt.Row("mutation_type", title=None, spacing=5),
        )
        .properties(width=200, height=75, title=pheno)
        .mark_bar()
        .resolve_scale(y="independent")
    )
    display(base_hist)
There are 278 mutations found in clades with growth estimates